// STL
#include <cstdio>                                 // std::remove()
#include <cstdlib>                                // EXIT_FAILURE, std::atoi(), std::size_t, std::system()
#include <fstream>                                // std::ifstream, std::ofstream
#include <functional>                             // std::bind(), ::placeholders::_x, ::ref(), ::cref()
#include <iostream>                               // std::cerr
#include <limits>                                 // std::numeric_limits<>::max_digits10
#include <stdexcept>                              // std::runtime_error
#include <string>                                 // std::string
#include <tuple>                                  // std::tie()
#include <vector>                                 // std::vector<>

// Boost
#include <boost/archive/text_iarchive.hpp>        // boost::archive::text_iarchive
#include <boost/archive/text_oarchive.hpp>        // boost::archive::text_oarchive

// jScience
#include "jScience/linalg.hpp"                    // Matrix<>

// stats++
#include "statsxx/data.hpp"                       // read_datafile()
#include "statsxx/machine_learning/Ensemble.hpp"  // Ensemble
#include "statsxx/machine_learning/NeuralNet.hpp" // NEURAL_NET

// this
#include "init.hpp"                               // read_param_file()


//
// Learner training callback: trains mlp on (X_in, X_out) by delegating to the external `mlp` program.
//
//     X_in                 :: training inputs  (written row-by-row to mlp_train_X_in_file)
//     X_out                :: training outputs (written row-by-row to mlp_train_X_out_file)
//     -----
//     mlp                  :: network to train; it is archived to mlp_file before the external run and re-read from mlp_file afterwards
//     -----
//     mlp_file             :: archive file used to exchange the network with the `mlp` program
//     -----
//     mlp_train_X_in_file  :: data file that receives X_in
//     mlp_train_X_out_file :: data file that receives X_out
//     -----
//     filename_mlp_train   :: argument appended to the `mlp` command line
//
// NOTE: mlp_file and both data files are deleted at the end of the call.
//
// NOTE: In main(), the first two arguments are left as placeholders and the remaining ones are bound with std::bind().
//
static void f_train(
                    const Matrix<double> &X_in,
                    const Matrix<double> &X_out,
                    // -----
                    NEURAL_NET           &mlp,
                    // -----
                    const std::string     mlp_file,
                    // -----
                    const std::string     mlp_train_X_in_file,
                    const std::string     mlp_train_X_out_file,
                    // -----
                    const std::string     filename_mlp_train
                    );


//
// Learner evaluation callback: evaluates mlp for a single input vector.
//
//     input :: input vector; it is wrapped into a one-element std::vector<std::vector<double>> before being passed to mlp.evaluate()
//     -----
//     mlp   :: network to evaluate
//
//     returns the output vector filled in by mlp.evaluate()
//
// NOTE: In main(), the first argument is left as a placeholder and mlp is bound with std::bind().
//
static std::vector<double> f_evaluate(
                                      const std::vector<double> &input,
                                      // -----
                                      NEURAL_NET                &mlp
                                      );


//
// USAGE: ensemble_mlp filename_param
//
//     filename_param :: parameters filename
//
//     =========================================================
//
// NOTE: See the example input for parameters.
//
// NOTE: The ensemble training methods implemented are described (at a high level) at:
//
//     https://en.wikipedia.org/wiki/Ensemble_learning
//
// -----
//
// TODO: NOTE: There are no safety checks implemented.
// TODO: NOTE: ... Examples include:
//
//     - Cross checks with the MLP files
//          - classifications, architecture sizes, filenames (training, etc.)
//     - Checks whether classification/regression subroutines are being called against the type of ensemble that we have.
//
// TODO: NOTE: Related to the former type of checks, one could write Boost-type subroutines [like read_param_file()] that read just the relevant information.
// TODO: NOTE: ... This would greatly simplify setting up these Ensemble calculations.
//
// -----
//
// TODO: NOTE: One type of ensemble method NOT implemented is "stacking".
//
// TODO: NOTE: Add polling? Majority vote?
//
// -----
//
// TODO: NOTE: It is difficult to determine how best to parallelize, because some functionalities (e.g., boosting) are sequential learning algorithms.
//
// TODO: NOTE: ... Parallelization may therefore be best on the stats++ side for the MLP, for example.
//
int main(
         int argc,
         char* argv[]
         )
{
    //=========================================================
    // INITIALIZATION
    //=========================================================

    std::string filename_param = std::string(argv[1]);

    // ----- PARAMETERS -----
    std::string ensemble_file;
    // =====
    // [mlp] BLOCK
    // -----
    std::string mlp_file;
    // -----
    std::string filename_mlp_create;
    // -----
    std::string mlp_train_X_in_file;
    std::string mlp_train_X_out_file;
    std::string filename_mlp_train;
    // =====
    // [create] BLOCK
    // -----
    bool        create;
    // -----
    int         nmlp;
    // -----
    bool        is_classif;
    // -----
    int         ni;
    int         no;
    // =====
    // [train] BLOCK
    // -----
    bool        train;
    // -----
    std::string train_X_in_file;
    std::string train_X_out_file;
    // -----
    std::string train_method;
    // =====
    // [optimize] BLOCK
    // -----
    bool        optimize;
    // -----
    std::string optimize_X_in_file;
    std::string optimize_X_out_file;
    // -----
    std::string optimize_method;
    // -----
    double      Ensemble_gen_err_eps;
    // -----
    int         BMC_N;
    // =====
    // [test] BLOCK
    // -----
    bool        test;
    // -----
    std::string test_X_in_file;
    std::string test_X_out_file;
    // -----
    double      cutoff;
    // =====
    std::string prefix;
    std::tie(
             ensemble_file,
             // =====
             mlp_file,
             // -----
             filename_mlp_create,
             // -----
             mlp_train_X_in_file,
             mlp_train_X_out_file,
             filename_mlp_train,
             // =====
             create,
             // -----
             nmlp,
             // -----
             is_classif,
             // -----
             ni,
             no,
             // =====
             train,
             // -----
             train_X_in_file,
             train_X_out_file,
             // -----
             train_method,
             // =====
             optimize,
             // -----
             optimize_X_in_file,
             optimize_X_out_file,
             // -----
             optimize_method,
             // -----
             Ensemble_gen_err_eps,
             // -----
             BMC_N,
             // =====
             test,
             // -----
             test_X_in_file,
             test_X_out_file,
             // -----
             cutoff,
             // =====
             prefix
             ) = read_param_file(
                                 filename_param
                                 );


    //=========================================================
    // CREATE (OR READ) ENSEMBLE
    //=========================================================
    //
    // NOTE: This always occurs (create or read-in Ensemble).
    //
    // TODO: NOTE: Probably in create, also set up individual directories to handle each learner.

    Ensemble<NEURAL_NET> ensemble;

    if( create )
    {
        ensemble = Ensemble<NEURAL_NET>(
                                        is_classif,
                                        // -----
                                        ni,
                                        no
                                        );

        for( int i = 0; i < nmlp; ++i )
        {
            // CREATE MLP
            std::string cmd = "mlp " + filename_mlp_create;
            std::system( cmd.c_str() );

            // READ MLP FROM ARCHIVE
            NEURAL_NET mlp;

            {
                std::ifstream ifs(mlp_file);

                boost::archive::text_iarchive ia(ifs);

                ia >> mlp;
            }

            // ADD LEARNER
            ensemble.add_learner(
                                 mlp
                                 );

            // CLEANUP
            cmd = "rm " + mlp_file;
            std::system( cmd.c_str() );
        }

        // SAVE ENSEMBLE TO ARCHIVE
        {
            std::ofstream ofs(ensemble_file);

            boost::archive::text_oarchive oa(ofs);

            oa << ensemble;
        }
    }
    else
    {
        // READ ENSEMBLE TO ARCHIVE
        {
            std::ifstream ifs(ensemble_file);

            boost::archive::text_iarchive ia(ifs);

            ia >> ensemble;
        }
    }

    //---------------------------------------------------------
    // ASSIGN Learner FUNCTIONS
    //---------------------------------------------------------

    for( auto i = 0; i < ensemble.learners.size(); ++i )
    {
        // TRAIN FUNCTION
        ensemble.learners[i].f_train = std::bind(
                                                 f_train,
                                                 // -----
                                                 std::placeholders::_1,
                                                 std::placeholders::_2,
                                                 // -----
                                                 std::ref(ensemble.learners[i]),
                                                 // -----
                                                 std::cref(mlp_file),
                                                 // -----
                                                 std::cref(mlp_train_X_in_file),
                                                 std::cref(mlp_train_X_out_file),
                                                 // -----
                                                 std::cref(filename_mlp_train)
                                                 );

        // EVALUATE FUNCTION
        ensemble.learners[i].f_evaluate = std::bind(
                                                    f_evaluate,
                                                    // -----
                                                    std::placeholders::_1,
                                                    // -----
                                                    std::ref(ensemble.learners[i])
                                                    );
    }

    //=========================================================
    // TRAIN
    //=========================================================

    if( train )
    {
        // READ TRAINING DATA
        Matrix<double> X_in  = read_datafile(
                                             train_X_in_file
                                             );

        Matrix<double> X_out = read_datafile(
                                             train_X_out_file
                                             );

        // TRAIN BY METHOD
        if( train_method == "bootstrap" )
        {
            ensemble.train_bootstrap(
                                     X_in,
                                     X_out
                                     );
        }
        else if( train_method == "boosting" )
        {
            ensemble.train_Real_AdaBoost(
                                         X_in,
                                         X_out
                                         );
        }
        // TODO:
        // else if( train_method == "bucket" )
        // else if( train_method == "stacking" )

        // SAVE ENSEMBLE TO ARCHIVE
        {
            std::ofstream ofs(ensemble_file);

            boost::archive::text_oarchive oa(ofs);

            oa << ensemble;
        }
    }

    //=========================================================
    // OPTIMIZE
    //=========================================================

    if( optimize )
    {
        // READ OPTIMIZE DATA
        Matrix<double> X_in  = read_datafile(
                                             optimize_X_in_file
                                             );

        Matrix<double> X_out = read_datafile(
                                             optimize_X_out_file
                                             );

        // OPTIMIZE BY METHOD
        if( optimize_method == "Ensemble_gen_err" )
        {
            ensemble.optimize_Ensemble_gen_err(
                                               X_in,
                                               X_out,
                                               // -----
                                               Ensemble_gen_err_eps,
                                               // -----
                                               prefix
                                               );
        }
        else if( optimize_method == "BMA" )
        {
            ensemble.optimize_BMA(
                                  X_in,
                                  X_out,
                                  // -----
                                  prefix
                                  );
        }
        else if( optimize_method == "BMC" )
        {
            ensemble.optimize_BMC(
                                  X_in,
                                  X_out,
                                  // -----
                                  BMC_N,
                                  // -----
                                  prefix
                                  );
        }
        // TODO:
        // else if( train_method == "bucket" )

        // OUTPUT WEIGHTS
        ensemble.output_weights(
                                prefix
                                );

        // SAVE ENSEMBLE TO ARCHIVE
        {
            std::ofstream ofs(ensemble_file);

            boost::archive::text_oarchive oa(ofs);

            oa << ensemble;
        }
    }

    //=========================================================
    // TEST
    //=========================================================

    if( test )
    {
        // TODO: NOTE: It is not clear yet if the following is more efficient and cleaner to keep testing contained (herein in Ensemble) or rely on mlp test.

        // READ TESTING DATA
        Matrix<double> X_in  = read_datafile(
                                             test_X_in_file
                                             );

        Matrix<double> X_out = read_datafile(
                                             test_X_out_file
                                             );

        // SETUP/OPEN OUTPUT FILES
        std::ofstream ofs(("./" + prefix + ".test.out.dat"));
        ofs.precision(std::numeric_limits<double>::max_digits10);

        std::ofstream ofs_c;

        if( ensemble.is_classif )
        {
            ofs_c.open(("./" + prefix + ".classifications.dat"));
        }

        // LOOP OVER TRAINING DATA
        for( auto i = 0; i < X_in.size(0); ++i )
        {
            std::vector<double> output = ensemble.evaluate(
                                                           X_in.row(i).std_vector()
                                                           );

            // OUTPUT
            for( auto j = 0; j < X_out.size(1); ++j )
            {
                ofs << output[j];

                if( j != (X_out.size(1)-1) )
                {
                    ofs << '\t';
                }
            }

            ofs << '\n';

            // CLASSIFICATION
            if( ensemble.is_classif )
            {
                // NOTE: Related to this: see the NOTEs in _MLP.
                if( output[0] > cutoff )
                {
                    ofs_c << "1" << '\n';
                }
                else
                {
                    ofs_c << "0" << '\n';
                }
            }
        }

        ofs.close();

        if( ensemble.is_classif )
        {
            ofs_c.close();
        }
    }





/*
    //=========================================================
    // CREATE (OR READ) MLP
    //=========================================================

    // NOTE: This always occurs (create or read-in MLP).
    NEURAL_NET mlp = create_MLP(
                                create,
                                // -----
                                create_dbn_file,
                                // -----
                                create_architecture,
                                create_af_type,
                                create_classif,
                                // -----
                                mlp_file
                                );

    //=========================================================
    // TRAIN
    //=========================================================

    if(train)
    {
        Matrix<double> X_in  = read_datafile(
                                             train_X_in_file
                                             );

        Matrix<double> X_out = read_datafile(
                                             train_X_out_file
                                             );

        mlp = train_MLP(
                        mlp,
                        // -----
                        train_method,
                        //------
                        train_nepoch_min,
                        train_nepoch_max,
                        // -----
                        train_lr,
                        train_lr_min,
                        train_lr_max,
                        // -----
                        train_momentum,
                        // -----
                        train_weight_penalty,
                        // -----
                        train_qrprop_u,
                        train_qrprop_d,
                        // -----
                        train_scg_lambda,
                        train_scg_sigma,
                        train_scg_convg_iterfrac,
                        train_scg_rk_tol,
                        // -----
                        train_EA_param,
                        // -----
                        X_in,
                        X_out,
                        // -----
                        train_npts_per_batch,
                        // -----
                        mlp_file
                        );
    }

    //=========================================================
    // TEST
    //=========================================================

    if(test)
    {
        std::vector<std::string> label = read_file<std::string>(
                                                                test_label_file
                                                                );

        Matrix<double> X_in = read_datafile(
                                            test_X_in_file
                                            );

        Matrix<double> X_out = read_datafile(
                                             test_X_out_file
                                             );

        test_MLP(
                 mlp,
                 // -----
                 label,
                 X_in,
                 X_out,
                 // -----
                 test_cutoff,
                 // -----
                 prefix
                 );
    }
*/

    return 0;
}


static void f_train(
                    const Matrix<double> &X_in,
                    const Matrix<double> &X_out,
                    // -----
                    NEURAL_NET           &mlp,
                    // -----
                    const std::string     mlp_file,
                    // -----
                    const std::string     mlp_train_X_in_file,
                    const std::string     mlp_train_X_out_file,
                    // -----
                    const std::string     filename_mlp_train
                    )
{
    // SAVE MLP TO ARCHIVE
    {
        std::ofstream ofs(mlp_file);

        boost::archive::text_oarchive oa(ofs);

        oa << mlp;
    }

    // WRITE DATA TO FILE
    std::ofstream ofs_in(mlp_train_X_in_file);
    std::ofstream ofs_out(mlp_train_X_out_file);

    for( auto j = 0; j < X_in.size(0); ++j )
    {
        for( auto k = 0; k < X_in.size(1); ++k )
        {
            ofs_in << X_in(j,k);

            if( k != (X_in.size(1)-1) )
            {
                ofs_in << " ";
            }
        }

        for( auto k = 0; k < X_out.size(1); ++k )
        {
            ofs_out << X_out(j,k);

            if( k != (X_out.size(1)-1) )
            {
                ofs_out << " ";
            }
        }

        if( j != (X_in.size(0)-1) )
        {
            ofs_in << '\n';
            ofs_out << '\n';
        }
    }

    ofs_in.close();
    ofs_out.close();

    // TRAIN
    std::string cmd = "mlp " + filename_mlp_train;
    std::system( cmd.c_str() );

    // READ (TRAINED) MLP FROM ARCHIVE
    {
        std::ifstream ifs(mlp_file);

        boost::archive::text_iarchive ia(ifs);

        ia >> mlp;
    }

    // CLEANUP
    cmd = "rm " + mlp_file;
    std::system( cmd.c_str() );

    cmd = "rm " + mlp_train_X_in_file;
    std::system( cmd.c_str() );

    cmd = "rm " + mlp_train_X_out_file;
    std::system( cmd.c_str() );
}


//
// Learner evaluation callback: evaluates mlp for a single input vector and returns the resulting output vector.
//
static std::vector<double> f_evaluate(
                                      const std::vector<double> &input,
                                      // -----
                                      NEURAL_NET                &mlp
                                      )
{
    // TODO: NOTE: The following assumes non-recurrent data.
    //
    // mlp.evaluate() is handed a sequence of input vectors, so the single input is wrapped into a one-element sequence.
    std::vector<std::vector<double>> sequence;
    sequence.push_back(input);

    std::vector<double> result;

    mlp.evaluate(
                 sequence,
                 result
                 );

    return result;
}
